import copy
from functools import partial, wraps
from math import sqrt

import flowvision
import oneflow as flow
import oneflow.nn.functional as F
from einops import rearrange, repeat
from oneflow import einsum, nn
from oneflow.autograd import grad as flow_grad

from libai.layers import Linear

from .einops_exts import Rearrange, rearrange_many
from .vector_quantize_flow import VectorQuantize as VQ

# constants

# Short alias for ``nn.ModuleList``; several classes below use it to hold layer stacks.
MList = nn.ModuleList

# helper functions


def exists(val):
    """Return ``True`` for any value other than ``None`` (falsy values count as existing)."""
    if val is None:
        return False
    return True


def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d


# decorators


def eval_decorator(fn):
    """Decorate a model method so that it runs with the model in eval mode.

    The wrapper records ``model.training``, calls ``model.eval()``, invokes
    ``fn`` and then restores the recorded mode with ``model.train(...)``.
    The mode is restored even when ``fn`` raises, so a failing call cannot
    leave a training model stuck in eval mode.

    Args:
        fn: callable whose first positional argument is the model.

    Returns:
        The wrapped callable; it returns whatever ``fn`` returns.
    """

    @wraps(fn)  # keep the wrapped function's name and docstring
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)

    return inner


def remove_vgg(fn):
    """Decorate a method so that ``self.vgg`` is hidden while it runs.

    If the instance has a ``vgg`` attribute, it is detached with ``delattr``
    before ``fn`` is called and re-attached afterwards. The attribute is
    re-attached even when ``fn`` raises, so a failing call cannot leave the
    instance without its ``vgg``.

    Args:
        fn: the method to wrap; called as ``fn(self, *args, **kwargs)``.

    Returns:
        The wrapped method; it returns whatever ``fn`` returns.
    """

    @wraps(fn)
    def inner(self, *args, **kwargs):
        if not hasattr(self, "vgg"):
            return fn(self, *args, **kwargs)

        vgg = self.vgg
        delattr(self, "vgg")
        try:
            return fn(self, *args, **kwargs)
        finally:
            # always put the attribute back, including on error
            self.vgg = vgg

    return inner


# keyword argument helpers


def pick_and_pop(keys, d):
    """Remove ``keys`` from ``d`` and return them as a new dict.

    ``keys`` is traversed exactly once, so one-shot iterables (generators,
    iterators) work as well as lists and tuples.

    Args:
        keys: iterable of keys to remove from ``d``.
        d: dict that is mutated in place.

    Returns:
        A dict mapping each key to the value that was popped from ``d``.

    Raises:
        KeyError: if a key is not present in ``d``.
    """
    return {key: d.pop(key) for key in keys}


def group_dict_by_key(cond, d):
    """Split ``d`` into two dicts according to a predicate on its keys.

    Args:
        cond: callable applied to each key; its truthiness decides the group.
        d: dict to partition (not modified).

    Returns:
        A tuple ``(matching, non_matching)`` of new dicts.
    """
    matching, non_matching = {}, {}
    for key, value in d.items():
        target = matching if cond(key) else non_matching
        target[key] = value
    return matching, non_matching


def string_begins_with(prefix, string_input):
    """Return whether ``string_input`` starts with ``prefix``.

    The prefix comes first so the function can be partially applied on it.
    """
    has_prefix = string_input.startswith(prefix)
    return has_prefix


def group_by_key_prefix(prefix, d):
    """Split ``d`` into ``(keys starting with prefix, all other keys)``."""

    def has_prefix(key):
        return string_begins_with(prefix, key)

    return group_dict_by_key(has_prefix, d)


def groupby_prefix_and_trim(prefix, d):
    """Separate the entries of ``d`` whose key starts with ``prefix``.

    Returns:
        A tuple ``(trimmed, rest)``: ``trimmed`` holds the prefixed entries with
        the prefix stripped from each key, ``rest`` holds every other entry
        unchanged.
    """
    prefixed, rest = group_dict_by_key(partial(string_begins_with, prefix), d)
    prefix_len = len(prefix)
    trimmed = {key[prefix_len:]: value for key, value in prefixed.items()}
    return trimmed, rest


# tensor helper functions


def log(t, eps=1e-10):
    """Return ``flow.log(t + eps)``; ``eps`` keeps the argument away from zero."""
    shifted = t + eps
    return flow.log(shifted)


def gradient_penalty(images, output, weight=10):
    """Penalise the gradient of ``output`` w.r.t. ``images`` for deviating from unit norm.

    Args:
        images: tensor the gradient is taken with respect to; it must already
            be part of the autograd graph that produced ``output``.
        output: tensor computed from ``images``.
        weight: scalar multiplier applied to the penalty.

    Returns:
        ``weight * mean((||grad||_2 - 1) ** 2)``, where the norm is taken per
        sample over all non-batch dimensions.
    """
    gradients = flow_grad(
        outputs=output,
        inputs=images,
        # ones_like keeps dtype/device in step with ``output``
        grad_outputs=flow.ones_like(output),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
        only_inputs=True,
    )[0]

    # flatten everything but the batch dimension before taking the norm
    gradients = rearrange(gradients, "b ... -> b (...)")
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()


def l2norm(t):
    """Normalize ``t`` along its last dimension with ``F.normalize``."""
    normalized = F.normalize(t, dim=-1)
    return normalized


def leaky_relu(p=0.1):
    """Build an ``nn.LeakyReLU`` activation module.

    Args:
        p: negative slope passed to ``nn.LeakyReLU``. Defaults to ``0.1``.

    Returns:
        A new ``nn.LeakyReLU`` instance.
    """
    # honour the argument instead of always using the default slope
    return nn.LeakyReLU(p)


def stable_softmax(t, dim=-1, alpha=32 ** 2):
    """Softmax over ``dim`` with the maximum subtracted in a down-scaled domain.

    The input is divided by ``alpha``, the (detached) maximum along ``dim`` is
    subtracted, and the result is scaled back by ``alpha`` before the softmax.
    """
    scaled = t / alpha
    peak = flow.amax(scaled, dim=dim, keepdim=True).detach()
    shifted = scaled - peak
    return (shifted * alpha).softmax(dim=dim)


def safe_div(numer, denom, eps=1e-8):
    """Divide ``numer`` by ``denom + eps``; ``eps`` guards against a zero denominator."""
    padded_denom = denom + eps
    return numer / padded_denom


# gan losses


def hinge_discr_loss(fake, real):
    """Hinge discriminator loss: ``mean(relu(1 + fake) + relu(1 - real))``."""
    fake_term = F.relu(1 + fake)
    real_term = F.relu(1 - real)
    return (fake_term + real_term).mean()


def hinge_gen_loss(fake):
    """Hinge generator loss: the negated mean of the logits for generated samples."""
    mean_logit = fake.mean()
    return -mean_logit


def bce_discr_loss(fake, real):
    """Binary cross-entropy discriminator loss on raw logits.

    Computes ``mean(-log(1 - sigmoid(fake)) - log(sigmoid(real)))`` using the
    epsilon-padded ``log`` helper.
    """
    fake_term = -log(1 - flow.sigmoid(fake))
    real_term = log(flow.sigmoid(real))
    return (fake_term - real_term).mean()


def bce_gen_loss(fake):
    """Binary cross-entropy generator loss: ``-mean(log(sigmoid(fake)))``."""
    log_prob = log(flow.sigmoid(fake))
    return -log_prob.mean()


def grad_layer_wrt_loss(loss, layer):
    """Return the detached gradient of ``loss`` with respect to ``layer``.

    The graph is retained so that further gradients can still be computed
    through it afterwards.
    """
    grads = flow_grad(
        outputs=loss,
        inputs=layer,
        grad_outputs=flow.ones_like(loss),
        retain_graph=True,
    )
    return grads[0].detach()


# vqgan vae


class LayerNormChan(nn.Module):
    """Normalization over dimension 1 with a learned per-channel scale and no bias."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # scale shaped (1, dim, 1, 1) so it broadcasts against 4-d inputs
        self.gamma = nn.Parameter(flow.ones(1, dim, 1, 1))

    def forward(self, x):
        """Subtract the mean and divide by the std along dim 1, then scale by ``gamma``."""
        variance = flow.var(x, dim=1, unbiased=False, keepdim=True)
        centre = flow.mean(x, dim=1, keepdim=True)
        std = (variance + self.eps).sqrt()
        return (x - centre) / std * self.gamma


# discriminator


class Discriminator(nn.Module):
    """Convolutional discriminator producing a map of logits.

    Args:
        dims: sequence of channel widths; ``dims[0]`` is the width after the
            first convolution, each following pair defines one stride-2 stage.
        channels: number of input image channels.
        groups: group count for the ``nn.GroupNorm`` in every stride-2 stage.
        init_kernel_size: kernel size of the first convolution.
    """

    def __init__(self, dims, channels=3, groups=16, init_kernel_size=5):
        super().__init__()

        # stem: size-preserving convolution (for odd kernel sizes) + activation
        stem = nn.Sequential(
            nn.Conv2d(channels, dims[0], init_kernel_size, padding=init_kernel_size // 2),
            leaky_relu(),
        )
        self.layers = MList([stem])

        # one stride-2 stage per consecutive pair of widths
        for dim_in, dim_out in zip(dims[:-1], dims[1:]):
            stage = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 4, stride=2, padding=1),
                nn.GroupNorm(groups, dim_out),
                leaky_relu(),
            )
            self.layers.append(stage)

        last_dim = dims[-1]
        # original note: returns 5 x 5 logits, for PatchGAN-esque training
        # (the actual map size depends on the input size — TODO confirm)
        self.to_logits = nn.Sequential(
            nn.Conv2d(last_dim, last_dim, 1), leaky_relu(), nn.Conv2d(last_dim, 1, 4)
        )

    def forward(self, x):
        """Run ``x`` through every stage in order and return the logits map."""
        hidden = x
        for stage in self.layers:
            hidden = stage(hidden)

        return self.to_logits(hidden)


# positional encoding


class ContinuousPositionBias(nn.Module):
    """Learned relative position bias for square feature maps.

    from https://arxiv.org/abs/2111.09883

    A small MLP maps log-scaled relative ``(row, col)`` offsets between every
    pair of positions to one bias value per attention head.

    Args:
        dim: hidden width of the MLP.
        heads: number of attention heads (output width of the MLP).
        layers: number of hidden ``Linear`` + activation layers.
    """

    def __init__(self, *, dim, heads, layers=2):
        super().__init__()
        self.net = MList([])
        self.net.append(nn.Sequential(Linear(2, dim), leaky_relu()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(Linear(dim, dim), leaky_relu()))

        self.net.append(Linear(dim, heads))
        # lazily filled cache of relative offsets; excluded from the state dict
        self.register_buffer("rel_pos", None, persistent=False)

    def forward(self, x):
        """Add the position bias to attention logits ``x``.

        Args:
            x: attention logits whose last dimension has length ``n``; the
                feature map side is taken as ``int(sqrt(n))``.

        Returns:
            ``x`` plus a bias of shape ``(heads, n, n)`` (broadcast over any
            leading dimensions).
        """
        n, device = x.shape[-1], x.device
        fmap_size = int(sqrt(n))
        num_positions = fmap_size * fmap_size

        cached = self.rel_pos
        # rebuild the cache when it is missing OR was built for another
        # feature-map size; a stale cache would yield a mis-shaped bias
        if not exists(cached) or cached.shape[0] != num_positions:
            pos = flow.arange(fmap_size, device=device)
            grid = flow.stack(flow.meshgrid(pos, pos, indexing="ij"))
            grid = rearrange(grid, "c i j -> (i j) c")
            # pairwise offsets between all positions: (n, n, 2)
            rel_pos = rearrange(grid, "i c -> i 1 c") - rearrange(grid, "j c -> 1 j c")
            # signed log scaling of the offsets
            rel_pos = flow.sign(rel_pos) * flow.log(rel_pos.abs() + 1)
            self.register_buffer("rel_pos", rel_pos, persistent=False)

        rel_pos = self.rel_pos.float()

        for layer in self.net:
            rel_pos = layer(rel_pos)

        bias = rearrange(rel_pos, "i j h -> h i j")
        return x + bias


# resnet encoder / decoder


class ResnetEncDec(nn.Module):
    """Convolutional encoder / decoder pair with optional residual blocks and attention.

    ``self.encoders`` and ``self.decoders`` are module lists that ``encode`` and
    ``decode`` apply in order. Each of the ``layers`` stages contributes a
    stride-2 ``Conv2d`` to the encoder and a stride-2 ``ConvTranspose2d`` to the
    decoder, plus optional ``ResBlock`` / ``GLUResBlock`` and ``VQGanAttention``
    modules.

    Args:
        dim: base channel width; must be divisible by ``resnet_groups``.
        channels: number of image channels (encoder input / decoder output).
        layers: number of stride-2 stages.
        layer_mults: per-stage channel multipliers applied to ``dim``; defaults
            to ``2 ** i`` for stage ``i``.
        num_resnet_blocks: residual blocks per stage. A tuple gives one count
            per stage; any other value is applied to the last stage only.
        resnet_groups: group count handed to the residual blocks.
        first_conv_kernel_size: kernel size of the encoder's first convolution.
        use_attn: whether to add attention per stage. A tuple gives one flag
            per stage; any other value is applied to the last stage only.
        attn_dim_head: ``dim_head`` passed to ``VQGanAttention``.
        attn_heads: ``heads`` passed to ``VQGanAttention``.
        attn_dropout: ``dropout`` passed to ``VQGanAttention``.
    """

    def __init__(
        self,
        dim,
        *,
        channels=3,
        layers=4,
        layer_mults=None,
        num_resnet_blocks=1,
        resnet_groups=16,
        first_conv_kernel_size=5,
        use_attn=True,
        attn_dim_head=64,
        attn_heads=8,
        attn_dropout=0.0,
    ):
        super().__init__()
        assert (
            dim % resnet_groups == 0
        ), f"dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)"

        self.layers = layers

        self.encoders = MList([])
        self.decoders = MList([])

        # default multipliers: 1, 2, 4, ... (one per stage)
        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
        assert (
            len(layer_mults) == layers
        ), "layer multipliers must be equal to designated number of layers"

        layer_dims = [dim * mult for mult in layer_mults]
        # prepend the base width so consecutive pairs give (in, out) per stage
        dims = (dim, *layer_dims)

        # channel width of the encoder output
        self.encoded_dim = dims[-1]

        dim_pairs = zip(dims[:-1], dims[1:])

        # small helpers so encoder (append) and decoder (prepend) construction read alike
        def append(arr, t):
            arr.append(t)

        def prepend(arr, t):
            arr.insert(0, t)

        # a non-tuple value configures the last stage only; earlier stages get 0 blocks
        if not isinstance(num_resnet_blocks, tuple):
            num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)

        # likewise, a non-tuple value enables/disables attention on the last stage only
        if not isinstance(use_attn, tuple):
            use_attn = (*((False,) * (layers - 1)), use_attn)

        assert (
            len(num_resnet_blocks) == layers
        ), "number of resnet blocks config must be equal to number of layers"
        assert len(use_attn) == layers

        # NOTE: layer_index is not used inside the loop body.
        # Decoder modules are inserted at the front, so later stages end up
        # earlier in ``self.decoders``; within one stage the resulting decoder
        # order is: GLUResBlocks, attention, ConvTranspose2d.
        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(
            range(layers), dim_pairs, num_resnet_blocks, use_attn
        ):
            # encoder stage starts with the stride-2 convolution
            append(
                self.encoders,
                nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride=2, padding=1), leaky_relu()),
            )
            # matching stride-2 transposed convolution for the decoder
            prepend(
                self.decoders,
                nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()),
            )

            # decoder attention goes in before the residual blocks are prepended,
            # so it sits after them in the final decoder order
            if layer_use_attn:
                prepend(
                    self.decoders,
                    VQGanAttention(
                        dim=dim_out, heads=attn_heads, dim_head=attn_dim_head, dropout=attn_dropout
                    ),
                )

            for _ in range(layer_num_resnet_blocks):
                append(self.encoders, ResBlock(dim_out, groups=resnet_groups))
                prepend(self.decoders, GLUResBlock(dim_out, groups=resnet_groups))

            # encoder attention comes last within the stage
            if layer_use_attn:
                append(
                    self.encoders,
                    VQGanAttention(
                        dim=dim_out, heads=attn_heads, dim_head=attn_dim_head, dropout=attn_dropout
                    ),
                )

        # input projection at the very front of the encoder
        prepend(
            self.encoders,
            nn.Conv2d(channels, dim, first_conv_kernel_size, padding=first_conv_kernel_size // 2),
        )
        # 1x1 convolution back to image channels at the very end of the decoder
        append(self.decoders, nn.Conv2d(dim, channels, 1))

    def get_encoded_fmap_size(self, image_size):
        """Return ``image_size // 2 ** layers`` (one factor of two per stride-2 stage)."""
        return image_size // (2 ** self.layers)

    @property
    def last_dec_layer(self):
        """Weight of the final decoder module (the 1x1 output convolution)."""
        return self.decoders[-1].weight

    def encode(self, x):
        """Apply every encoder module to ``x`` in order."""
        for enc in self.encoders:
            x = enc(x)
        return x

    def decode(self, x):
        """Apply every decoder module to ``x`` in order."""
        for dec in self.decoders:
            x = dec(x)
        return x


class GLUResBlock(nn.Module):
    """Residual block using GLU gates: two (conv -> GLU -> GroupNorm) steps and a 1x1 conv."""

    def __init__(self, chan, groups=16):
        super().__init__()
        doubled = chan * 2  # GLU over dim 1 halves the channel count again
        stack = [
            nn.Conv2d(chan, doubled, 3, padding=1),
            nn.GLU(dim=1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, doubled, 3, padding=1),
            nn.GLU(dim=1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return ``net(x) + x``."""
        branch = self.net(x)
        return branch + x


class ResBlock(nn.Module):
    """Residual block: two (conv -> GroupNorm -> leaky ReLU) steps and a 1x1 conv."""

    def __init__(self, chan, groups=16):
        super().__init__()
        stack = [
            nn.Conv2d(chan, chan, 3, padding=1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 3, padding=1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return ``net(x) + x``."""
        branch = self.net(x)
        return branch + x


# vqgan attention layer


class VQGanAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map, with a residual connection.

    Args:
        dim: channel count of the input and output feature map.
        dim_head: channels per attention head.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, *, dim, dim_head=64, heads=8, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        self.dropout = nn.Dropout(dropout)
        self.pre_norm = LayerNormChan(dim)

        self.cpb = ContinuousPositionBias(dim=dim // 4, heads=heads)
        # 1x1 convolutions act as per-position linear projections
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias=False)

    def forward(self, x):
        """Attend over all spatial positions of ``x`` and add the input back."""
        height, width = x.shape[-2:]
        residual = x.clone()

        normed = self.pre_norm(x)

        # split the fused projection into queries, keys and values
        q, k, v = self.to_qkv(normed).chunk(3, dim=1)

        # flatten the spatial grid and separate the heads
        q, k, v = (
            rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads) for t in (q, k, v)
        )

        logits = einsum("b h c i, b h c j -> b h i j", q, k) * self.scale

        # add the learned relative position bias
        logits = self.cpb(logits)

        weights = self.dropout(stable_softmax(logits, dim=-1))

        attended = einsum("b h i j, b h c j -> b h c i", weights, v)
        attended = rearrange(attended, "b h c (x y) -> b (h c) x y", x=height, y=width)
        projected = self.to_out(attended)

        return projected + residual


# ViT encoder / decoder


class RearrangeImage(nn.Module):
    """Unflatten dimension 1 of ``x`` into a square ``(h, w)`` grid."""

    def forward(self, x):
        """Reshape ``(b, n, ...)`` to ``(b, side, side, ...)`` with ``side = int(sqrt(n))``."""
        side = int(sqrt(x.shape[1]))
        return rearrange(x, "b (h w) ... -> b h w ...", h=side, w=side)


class Attention(nn.Module):
    """Multi-head self-attention over a token sequence, with a leading LayerNorm.

    Args:
        dim: feature width of the input and output tokens.
        heads: number of attention heads.
        dim_head: feature width per head.
    """

    def __init__(self, dim, *, heads=8, dim_head=32):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        """Return the attention output for ``x``; the residual is added by the caller."""
        normed = self.norm(x)

        # fused projection, then split into heads
        qkv = self.to_qkv(normed).chunk(3, dim=-1)
        q, k, v = rearrange_many(qkv, "b n (h d) -> b h n d", h=self.heads)

        q = q * self.scale
        logits = einsum("b h i d, b h j d -> b h i j", q, k)

        # subtract the row maximum before the softmax
        logits = logits - logits.amax(dim=-1, keepdim=True).detach()
        weights = logits.softmax(dim=-1)

        mixed = einsum("b h i j, b h j d -> b h i d", weights, v)

        # merge the heads back into the feature dimension
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.to_out(merged)


def FeedForward(dim, mult=4):
    """Build the feed-forward sub-block: LayerNorm -> Linear -> GELU -> Linear.

    The hidden width is ``dim * mult``; both linear layers are bias-free.
    """
    hidden_dim = dim * mult
    modules = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*modules)


class Transformer(nn.Module):
    """Stack of (attention, feed-forward) pairs with residual connections and a final LayerNorm.

    Args:
        dim: token feature width.
        layers: number of (attention, feed-forward) pairs.
        dim_head: feature width per attention head.
        heads: number of attention heads.
        ff_mult: hidden-width multiplier of the feed-forward blocks.
    """

    def __init__(self, dim, *, layers, dim_head=32, heads=8, ff_mult=4):
        super().__init__()
        blocks = [
            nn.ModuleList(
                [
                    Attention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(layers)
        ]
        self.layers = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply every block with residual additions, then the final norm."""
        hidden = x
        for attn, ff in self.layers:
            hidden = attn(hidden) + hidden
            hidden = ff(hidden) + hidden

        return self.norm(hidden)


class ViTEncDec(nn.Module):
    """Patch-based transformer encoder / decoder pair.

    Args:
        dim: token feature width; also the channel width of the encoded map.
        channels: number of image channels.
        layers: number of transformer blocks in the encoder and in the decoder.
        patch_size: side length of the square patches.
        dim_head: feature width per attention head.
        heads: number of attention heads.
        ff_mult: hidden-width multiplier of the feed-forward blocks.
    """

    def __init__(self, dim, channels=3, layers=4, patch_size=8, dim_head=32, heads=8, ff_mult=4):
        super().__init__()
        self.encoded_dim = dim
        self.patch_size = patch_size

        # flattened size of one patch
        input_dim = channels * (patch_size ** 2)
        transformer_kwargs = dict(
            dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult, layers=layers
        )

        # image -> patch tokens -> transformer -> (b, c, h, w) feature map
        encoder_modules = [
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size),
            Linear(input_dim, dim),
            Transformer(**transformer_kwargs),
            RearrangeImage(),
            Rearrange("b h w c -> b c h w"),
        ]
        self.encoder = nn.Sequential(*encoder_modules)

        # feature map -> tokens -> transformer -> patch pixels -> image
        decoder_modules = [
            Rearrange("b c h w -> b (h w) c"),
            Transformer(**transformer_kwargs),
            nn.Sequential(
                Linear(dim, dim * 4, bias=False),
                nn.Tanh(),
                Linear(dim * 4, input_dim, bias=False),
            ),
            RearrangeImage(),
            Rearrange("b h w (p1 p2 c) -> b c (h p1) (w p2)", p1=patch_size, p2=patch_size),
        ]
        self.decoder = nn.Sequential(*decoder_modules)

    def get_encoded_fmap_size(self, image_size):
        """Return the encoded side length, ``image_size // patch_size``."""
        return image_size // self.patch_size

    @property
    def last_dec_layer(self):
        """Weight of the last ``Linear`` in the decoder's token-to-pixel projection."""
        # decoder[-3] is the Linear/Tanh/Linear sub-sequence; [-1] is its final Linear
        to_pixels = self.decoder[-3]
        return to_pixels[-1].weight

    def encode(self, x):
        """Run the encoder pipeline on ``x``."""
        return self.encoder(x)

    def decode(self, x):
        """Run the decoder pipeline on ``x``."""
        return self.decoder(x)


# main vqgan-vae classes


class NullVQGanVAE(nn.Module):
    """Pass-through stand-in for a VAE: ``encode`` and ``decode`` return their input."""

    def __init__(self, *, channels):
        super().__init__()
        self.layers = 0
        # the "encoded" representation is the input itself, so its width is ``channels``
        self.encoded_dim = channels

    def get_encoded_fmap_size(self, size):
        """Return ``size`` unchanged."""
        return size

    def copy_for_eval(self):
        """Return this very instance (no copy is made)."""
        return self

    def encode(self, x):
        """Identity."""
        return x

    def decode(self, x):
        """Identity."""
        return x


class VQGanVAE(nn.Module):
    """Vector-quantized autoencoder with optional perceptual (VGG) and GAN losses.

    The model is made of an encoder/decoder (``ResnetEncDec`` or ``ViTEncDec``),
    a vector quantizer and, when ``use_vgg_and_gan`` is true, a VGG network for
    the perceptual loss plus a ``Discriminator`` for the adversarial loss.

    Args:
        dim: base width handed to the encoder/decoder and the discriminator.
        image_size: required height and width of the input images.
        channels: required number of input channels.
        layers: number of encoder/decoder layers.
        l2_recon_loss: use ``F.mse_loss`` instead of ``F.l1_loss`` for reconstruction.
        use_hinge_loss: use hinge GAN losses instead of the BCE variants.
        vgg: optional pre-built network for the perceptual loss; when ``None``
            a pretrained ``flowvision`` VGG16 is loaded and its classifier is
            truncated by its last two modules.
        vq_codebook_dim, vq_codebook_size, vq_decay, vq_commitment_weight,
        vq_kmeans_init, vq_use_cosine_sim: forwarded to the vector quantizer.
        use_vgg_and_gan: when false, neither VGG nor discriminator is created
            and the autoencoder loss is the reconstruction loss only.
        vae_type: ``"resnet"`` or ``"vit"``.
        discr_layers: number of stride-2 stages in the discriminator.
        **kwargs: extra options; keys starting with ``vq_`` go to the quantizer
            and keys starting with ``encdec_`` go to the encoder/decoder (both
            with the prefix removed). Other keys are ignored.

    Raises:
        ValueError: if ``vae_type`` is not one of the supported values.
    """

    def __init__(
        self,
        *,
        dim,
        image_size,
        channels=3,
        layers=4,
        l2_recon_loss=False,
        use_hinge_loss=True,
        vgg=None,
        vq_codebook_dim=256,
        vq_codebook_size=512,
        vq_decay=0.8,
        vq_commitment_weight=1.0,
        vq_kmeans_init=True,
        vq_use_cosine_sim=True,
        use_vgg_and_gan=True,
        vae_type="resnet",
        discr_layers=4,
        **kwargs,
    ):
        super().__init__()
        # route prefixed keyword arguments to their sub-modules
        vq_kwargs, kwargs = groupby_prefix_and_trim("vq_", kwargs)
        encdec_kwargs, kwargs = groupby_prefix_and_trim("encdec_", kwargs)

        self.image_size = image_size
        self.channels = channels
        self.codebook_size = vq_codebook_size

        if vae_type == "resnet":
            enc_dec_klass = ResnetEncDec
        elif vae_type == "vit":
            enc_dec_klass = ViTEncDec
        else:
            raise ValueError(f"{vae_type} not valid")

        self.enc_dec = enc_dec_klass(dim=dim, channels=channels, layers=layers, **encdec_kwargs)

        self.vq = VQ(
            dim=self.enc_dec.encoded_dim,
            codebook_dim=vq_codebook_dim,
            codebook_size=vq_codebook_size,
            decay=vq_decay,
            commitment_weight=vq_commitment_weight,
            accept_image_fmap=True,
            kmeans_init=vq_kmeans_init,
            use_cosine_sim=vq_use_cosine_sim,
            **vq_kwargs,
        )

        # reconstruction loss

        self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss

        # GAN and perceptual losses are optional; both stay disabled
        # (vgg / discr left as None) when use_vgg_and_gan is false

        self.vgg = None
        self.discr = None
        self.use_vgg_and_gan = use_vgg_and_gan

        if not use_vgg_and_gan:
            return

        # perceptual loss

        if exists(vgg):
            self.vgg = vgg
        else:
            self.vgg = flowvision.models.vgg16(pretrained=True)
            # drop the last two classifier modules so features, not class logits, are compared
            self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])

        # gan related losses

        layer_mults = [2 ** t for t in range(discr_layers)]
        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        self.discr = Discriminator(dims=dims, channels=channels)

        self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
        self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss

    @property
    def encoded_dim(self):
        """Channel width of the encoded feature map."""
        return self.enc_dec.encoded_dim

    def get_encoded_fmap_size(self, image_size):
        """Return the side length of the encoded feature map for ``image_size``."""
        return self.enc_dec.get_encoded_fmap_size(image_size)

    def copy_for_eval(self):
        """Return a deep copy in eval mode, without discriminator and VGG.

        The copy is placed on the device of this model's first parameter.
        """
        device = next(self.parameters()).device
        # NOTE(review): ``self.cpu()`` is called on the source model itself and
        # nothing moves it back afterwards — presumably this leaves the source
        # on the CPU (as PyTorch-style ``Module.cpu`` would); confirm whether
        # callers rely on that.
        vae_copy = copy.deepcopy(self.cpu())

        if vae_copy.use_vgg_and_gan:
            del vae_copy.discr
            del vae_copy.vgg

        vae_copy.eval()
        return vae_copy.to(device)

    @remove_vgg
    def state_dict(self, *args, **kwargs):
        """Return the state dict with the ``vgg`` attribute temporarily removed."""
        return super().state_dict(*args, **kwargs)

    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        """Load a state dict with the ``vgg`` attribute temporarily removed."""
        return super().load_state_dict(*args, **kwargs)

    @property
    def codebook(self):
        """The quantizer's codebook."""
        return self.vq.codebook

    def encode(self, fmap):
        """Encode ``fmap`` with the encoder; quantization happens in ``decode``."""
        return self.enc_dec.encode(fmap)

    def decode(self, fmap, return_indices_and_loss=False):
        """Quantize ``fmap`` and decode it.

        Args:
            fmap: encoded feature map (pre-quantization).
            return_indices_and_loss: also return the quantizer's indices and
                commitment loss.

        Returns:
            The decoded tensor, or ``(decoded, indices, commit_loss)``.
        """
        fmap, indices, commit_loss = self.vq(fmap)

        fmap = self.enc_dec.decode(fmap)

        if return_indices_and_loss:
            return fmap, indices, commit_loss

        return fmap

    def forward(
        self,
        img,
        return_loss=False,
        return_discr_loss=False,
        return_recons=False,
        add_gradient_penalty=True,
    ):
        """Reconstruct ``img`` and optionally compute a training loss.

        Args:
            img: 4-d image batch; channels, height and width must match the
                values given at construction.
            return_loss: return the autoencoder loss.
            return_discr_loss: return the discriminator loss. Mutually
                exclusive with ``return_loss``.
            return_recons: additionally return the reconstruction alongside
                the loss.
            add_gradient_penalty: add the gradient penalty to the
                discriminator loss.

        Returns:
            The reconstruction when no loss is requested; otherwise the loss,
            or ``(loss, reconstruction)`` when ``return_recons`` is true.
        """
        _, channels, height, width = img.shape
        assert (
            height == self.image_size and width == self.image_size
        ), f"height and width of input image must be equal to {self.image_size}"
        assert (
            channels == self.channels
        ), "number of channels on image or sketch is not equal to the channels set on this VQGanVAE"

        fmap = self.encode(img)

        fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss=True)

        if not return_loss and not return_discr_loss:
            return fmap

        assert (
            return_loss ^ return_discr_loss
        ), "you should either return autoencoder loss or discriminator loss, but not both"

        # whether to return discriminator loss

        if return_discr_loss:
            assert exists(self.discr), "discriminator must exist to train it"

            # the discriminator step must not back-propagate into the autoencoder
            fmap.detach_()
            # needed so the gradient penalty can differentiate w.r.t. the images
            img.requires_grad_()

            fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))

            # start from the plain discriminator loss so ``loss`` is always
            # defined, with or without the gradient penalty
            loss = self.discr_loss(fmap_discr_logits, img_discr_logits)

            if add_gradient_penalty:
                loss = loss + gradient_penalty(img, img_discr_logits)

            if return_recons:
                return loss, fmap

            return loss

        # reconstruction loss

        recon_loss = self.recon_loss_fn(fmap, img)

        # early return when perceptual and GAN losses are disabled

        if not self.use_vgg_and_gan:
            if return_recons:
                return recon_loss, fmap

            return recon_loss

        # perceptual loss

        img_vgg_input = img
        fmap_vgg_input = fmap

        if img.shape[1] == 1:
            # handle grayscale for vgg: repeat the single channel three times
            img_vgg_input, fmap_vgg_input = map(
                lambda t: repeat(t, "b 1 ... -> b c ...", c=3), (img_vgg_input, fmap_vgg_input)
            )

        img_vgg_feats = self.vgg(img_vgg_input)
        recon_vgg_feats = self.vgg(fmap_vgg_input)
        perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)

        # generator loss

        gen_loss = self.gen_loss(self.discr(fmap))

        # calculate adaptive weight: ratio of the gradient norms of the
        # perceptual and generator losses at the last decoder layer

        last_dec_layer = self.enc_dec.last_dec_layer

        norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
        norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(
            p=2
        )

        adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
        adaptive_weight.clamp_(max=1e4)

        # combine losses

        loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss

        if return_recons:
            return loss, fmap

        return loss
